import

import torch 
from fastai.vision.all import * 
import cv2 as cv
import fastbook
from fastbook import *
from fastai.vision.widgets import *

data

# Root directory of the chest X-ray dataset; ls() lists its immediate children.
path=Path('/home/khy/chest_xray/chest_xray') 
path.ls()
(#5) [Path('/home/khy/chest_xray/chest_xray/train'),Path('/home/khy/chest_xray/chest_xray/test'),Path('/home/khy/chest_xray/chest_xray/chest_xray'),Path('/home/khy/chest_xray/chest_xray/__MACOSX'),Path('/home/khy/chest_xray/chest_xray/val')]
# Collect every image file found under `path`, including those in nested
# sub-directories.
files=get_image_files(path)
files
(#11712) [Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0766-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/NORMAL2-IM-1318-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0160-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/NORMAL2-IM-1327-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0489-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0509-0001-0002.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0761-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0416-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/NORMAL2-IM-0566-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0411-0001.jpeg')...]
# Build the DataLoaders straight from the folder tree; every item is resized
# to 224 px.
# NOTE(review): with valid_pct set, a random 20% of the images found under
# `path` is presumably held out for validation, and no seed is given, so the
# split differs between runs. The directory listing shows a nested chest_xray/
# folder next to train/test/val -- if it duplicates the data, copies of the
# same image can land on both sides of the split. TODO confirm.
dls = ImageDataLoaders.from_folder(
    path,
    train='train',
    valid_pct=0.2,
    item_tfms=Resize(224),
)
dls.vocab
['NORMAL', 'PNEUMONIA']
# Preview up to 16 items of a batch.
dls.show_batch(max_n=16)

# cnn_learner is only used to obtain the resnet34 body (model[0]). Its head
# (model[1]) is replaced by global-average-pool -> flatten -> bias-free linear
# layer, so class activation maps can be read off the linear weights.
learn = cnn_learner(dls, resnet34, metrics=error_rate)
net1, net2 = learn.model[0], learn.model[1]
net2 = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d(output_size=1),
    torch.nn.Flatten(),
    torch.nn.Linear(in_features=512, out_features=2, bias=False),
)
net = torch.nn.Sequential(net1, net2)

# Train body + CAM head together, tracking accuracy.
lrnr2 = Learner(dls, net, metrics=accuracy)
lrnr2.fine_tune(200)
epoch train_loss valid_loss accuracy time
0 0.166842 0.091861 0.967122 00:33
epoch train_loss valid_loss accuracy time
0 0.076691 0.070642 0.973954 00:33
1 0.065596 0.065189 0.976943 00:33
2 0.063810 0.060881 0.977797 00:33
3 0.058133 0.055606 0.979505 00:33
4 0.047295 0.051751 0.982494 00:33
5 0.049507 0.061955 0.975235 00:33
6 0.040383 0.048890 0.982494 00:33
7 0.037072 0.038793 0.985483 00:33
8 0.029895 0.035411 0.988044 00:33
9 0.024122 0.032279 0.988471 00:33
10 0.022319 0.030799 0.990606 00:33
11 0.022883 0.029063 0.990606 00:33
12 0.018799 0.024217 0.993595 00:33
13 0.018655 0.026862 0.991887 00:33
14 0.017203 0.025556 0.991460 00:33
15 0.012168 0.028741 0.991887 00:33
16 0.013291 0.021540 0.991460 00:33
17 0.013113 0.023177 0.993595 00:33
18 0.014589 0.023715 0.993168 00:33
19 0.010889 0.027784 0.992314 00:33
20 0.010598 0.028819 0.992314 00:33
21 0.013652 0.023543 0.993168 00:33
22 0.010165 0.021542 0.993168 00:33
23 0.011329 0.024496 0.992314 00:33
24 0.009473 0.019847 0.992314 00:33
25 0.007470 0.022198 0.990179 00:33
26 0.007615 0.017968 0.995303 00:33
27 0.006131 0.022273 0.995730 00:33
28 0.008292 0.032437 0.992314 00:33
29 0.008912 0.042545 0.988898 00:33
30 0.009870 0.039163 0.988044 00:33
31 0.010967 0.018784 0.992314 00:33
32 0.006510 0.021688 0.991887 00:33
33 0.006636 0.033374 0.991460 00:33
34 0.010336 0.020198 0.993595 00:33
35 0.009317 0.030448 0.991033 00:33
36 0.007046 0.022307 0.993168 00:33
37 0.009590 0.026956 0.990606 00:33
38 0.006269 0.055886 0.985056 00:33
39 0.010013 0.018850 0.994449 00:33
40 0.008058 0.027818 0.993168 00:33
41 0.007327 0.015476 0.993595 00:33
42 0.006886 0.010855 0.997011 00:33
43 0.011692 0.017141 0.997011 00:33
44 0.007462 0.030888 0.990179 00:33
45 0.006464 0.015794 0.992741 00:33
46 0.007760 0.068463 0.984628 00:33
47 0.006637 0.015711 0.993168 00:33
48 0.010105 0.041067 0.988898 00:33
49 0.007672 0.012651 0.996157 00:33
50 0.014199 0.083004 0.974381 00:33
51 0.012289 0.018203 0.993168 00:33
52 0.009026 0.020449 0.994022 00:33
53 0.004553 0.017501 0.993595 00:33
54 0.010326 0.024923 0.991033 00:33
55 0.015319 0.027962 0.992314 00:33
56 0.004357 0.023815 0.994022 00:33
57 0.005287 0.019874 0.992314 00:33
58 0.009573 0.014026 0.995730 00:33
59 0.006735 0.021964 0.993168 00:33
60 0.005811 0.023319 0.990606 00:33
61 0.011406 0.026691 0.992741 00:33
62 0.005277 0.022868 0.994449 00:33
63 0.006119 0.018390 0.994022 00:33
64 0.007875 0.034545 0.994022 00:33
65 0.005800 0.020408 0.994022 00:33
66 0.002680 0.019692 0.994449 00:33
67 0.006419 0.034546 0.991033 00:33
68 0.006348 0.053590 0.986763 00:33
69 0.005590 0.031790 0.993595 00:33
70 0.007865 0.029411 0.994876 00:33
71 0.002760 0.026847 0.993168 00:33
72 0.009839 0.030372 0.992741 00:33
73 0.008680 0.026388 0.992314 00:33
74 0.004330 0.031201 0.992741 00:33
75 0.009632 0.078810 0.984202 00:33
76 0.003771 0.022387 0.992741 00:33
77 0.006113 0.030133 0.992314 00:33
78 0.003496 0.028839 0.995303 00:33
79 0.003018 0.026174 0.994022 00:33
80 0.007461 0.030011 0.993595 00:33
81 0.004392 0.023791 0.994876 00:33
82 0.005972 0.068508 0.987617 00:33
83 0.006191 0.019870 0.996584 00:33
84 0.005330 0.020402 0.996584 00:33
85 0.002982 0.036186 0.993168 00:33
86 0.003956 0.019152 0.994022 00:33
87 0.006709 0.022051 0.994449 00:33
88 0.004887 0.043770 0.991460 00:33
89 0.004027 0.025353 0.993168 00:33
90 0.002959 0.029085 0.992741 00:33
91 0.003077 0.025070 0.993595 00:33
92 0.004699 0.024857 0.992741 00:33
93 0.002660 0.032952 0.995730 00:33
94 0.003100 0.025073 0.994876 00:33
95 0.002563 0.023130 0.994022 00:33
96 0.001407 0.023987 0.995730 00:33
97 0.002879 0.015754 0.996584 00:33
98 0.002273 0.019964 0.995730 00:33
99 0.001539 0.023395 0.994022 00:33
100 0.002776 0.019369 0.997438 00:33
101 0.001925 0.015023 0.996157 00:33
102 0.002006 0.039217 0.991887 00:33
103 0.003615 0.011737 0.997011 00:33
104 0.002477 0.016405 0.995730 00:33
105 0.001914 0.014328 0.997438 00:33
106 0.000848 0.020702 0.995730 00:33
107 0.005377 0.028292 0.994022 00:33
108 0.003150 0.019413 0.996584 00:33
109 0.001558 0.022858 0.995730 00:33
110 0.002981 0.022044 0.995730 00:33
111 0.003152 0.024832 0.993595 00:33
112 0.001988 0.016285 0.995730 00:33
113 0.000533 0.014695 0.995730 00:33
114 0.000902 0.017304 0.995730 00:33
115 0.001843 0.019725 0.995730 00:33
116 0.001038 0.020030 0.995730 00:33
117 0.000729 0.019264 0.994022 00:33
118 0.001277 0.027110 0.994876 00:33
119 0.001734 0.026816 0.993168 00:33
120 0.002050 0.020589 0.995730 00:33
121 0.002221 0.022525 0.995730 00:33
122 0.000572 0.027818 0.993168 00:33
123 0.001051 0.018991 0.994876 00:33
124 0.000295 0.019816 0.994876 00:33
125 0.001252 0.022995 0.995730 00:33
126 0.000770 0.021016 0.994449 00:33
127 0.000683 0.030154 0.994449 00:33
128 0.003303 0.026239 0.995730 00:33
129 0.001704 0.025088 0.994022 00:33
130 0.002516 0.010910 0.996584 00:33
131 0.000699 0.015325 0.996584 00:33
132 0.000870 0.013863 0.996584 00:33
133 0.000663 0.020103 0.995730 00:33
134 0.000980 0.012507 0.996584 00:33
135 0.000181 0.014895 0.995730 00:33
136 0.000645 0.030882 0.994022 00:33
137 0.000258 0.029726 0.994022 00:33
138 0.000154 0.019418 0.995730 00:33
139 0.000699 0.019971 0.995730 00:33
140 0.000355 0.024038 0.994876 00:33
141 0.000170 0.030813 0.994876 00:33
142 0.000657 0.027899 0.994876 00:33
143 0.001425 0.024708 0.995730 00:33
144 0.000381 0.020135 0.994022 00:34
145 0.000152 0.025634 0.994876 00:33
146 0.000075 0.018921 0.994876 00:33
147 0.000226 0.017673 0.994876 00:33
148 0.000224 0.023066 0.996584 00:33
149 0.000632 0.018082 0.994876 00:33
150 0.000625 0.016179 0.996584 00:33
151 0.000080 0.021201 0.994876 00:33
152 0.000068 0.021460 0.994022 00:33
153 0.000112 0.018794 0.995730 00:33
154 0.000080 0.021812 0.994876 00:33
155 0.000040 0.018293 0.995730 00:33
156 0.000171 0.018570 0.997438 00:33
157 0.000175 0.015313 0.996584 00:33
158 0.000464 0.016535 0.996584 00:33
159 0.000109 0.019572 0.996584 00:33
160 0.000062 0.021594 0.996584 00:33
161 0.000064 0.014384 0.996584 00:33
162 0.000014 0.020526 0.996584 00:33
163 0.000028 0.019420 0.995730 00:33
164 0.000042 0.030555 0.994876 00:33
165 0.000080 0.022019 0.996584 00:33
166 0.000079 0.030117 0.994876 00:33
167 0.000038 0.019891 0.996584 00:33
168 0.000027 0.024130 0.996584 00:33
169 0.000017 0.027270 0.995730 00:33
170 0.000032 0.018282 0.995730 00:33
171 0.000062 0.019155 0.996584 00:33
172 0.000059 0.023948 0.995730 00:33
173 0.000011 0.025428 0.995730 00:33
174 0.000011 0.019787 0.995730 00:33
175 0.000018 0.025644 0.995730 00:33
176 0.000185 0.021899 0.995730 00:33
177 0.000056 0.021866 0.995730 00:33
178 0.000061 0.022560 0.995730 00:33
179 0.000019 0.019159 0.995730 00:33
180 0.000009 0.024180 0.995730 00:33
181 0.000030 0.022470 0.995730 00:33
182 0.000007 0.020468 0.995730 00:33
183 0.000049 0.024680 0.995730 00:33
184 0.000009 0.019799 0.994876 00:33
185 0.000026 0.025008 0.995730 00:33
186 0.000028 0.029448 0.995730 00:33
187 0.000161 0.032871 0.995730 00:33
188 0.000334 0.028276 0.995730 00:33
189 0.000033 0.023425 0.995730 00:33
190 0.000012 0.027646 0.995730 00:33
191 0.000012 0.026857 0.995730 00:33
192 0.000120 0.025125 0.995730 00:33
193 0.000014 0.029498 0.995730 00:33
194 0.000010 0.028255 0.995730 00:33
195 0.000098 0.027213 0.995730 00:33
196 0.000031 0.024639 0.995730 00:33
197 0.000021 0.028268 0.995730 00:33
198 0.000005 0.021215 0.995730 00:33
199 0.000010 0.027356 0.995730 00:33

CAM 결과 확인_에폭 200

# 5x5 grid of class activation maps for image files 0..24.
# Each panel shows the input with the CAM of the predicted class on top, and
# the softmax probability of that class in the title.
cam_files = get_image_files(path)  # list the files once, not once per panel
model_device = next(net.parameters()).device
fig, ax = plt.subplots(5, 5)
k = 0
for i in range(5):
    for j in range(5):
        x, = first(dls.test_dl([PILImage.create(cam_files[k])]))
        # Feed the model on whatever device it currently lives on; `x` itself
        # is left untouched for dls.train.decode below.
        x_in = x.to(model_device)
        camimg = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x_in).squeeze())
        a, b = net(x_in).tolist()[0]
        exp_a, exp_b = np.exp(a), np.exp(b)
        normalprob, pneumoniaprob = exp_a / (exp_a + exp_b), exp_b / (exp_a + exp_b)
        # Pick the CAM row / label / probability of the predicted class.
        if normalprob > pneumoniaprob:
            cls_idx, cls_name, cls_prob = 0, "normal", normalprob
        else:
            cls_idx, cls_name, cls_prob = 1, "pneumonia", pneumoniaprob
        dls.train.decode((x,))[0].squeeze().show(ax=ax[i][j])
        ax[i][j].imshow(
            camimg[cls_idx].to("cpu").detach(),
            alpha=0.5,
            extent=(0, 224, 224, 0),
            interpolation='bilinear',
            cmap='magma',
        )
        ax[i][j].set_title("%s(%s)" % (cls_name, cls_prob.round(5)))
        k = k + 1
fig.set_figwidth(16)
fig.set_figheight(16)
fig.tight_layout()
# 5x5 grid of class activation maps for image files 3000..3024.
# Each panel shows the input with the CAM of the predicted class on top, and
# the softmax probability of that class in the title.
cam_files = get_image_files(path)  # list the files once, not once per panel
model_device = next(net.parameters()).device
fig, ax = plt.subplots(5, 5)
k = 3000
for i in range(5):
    for j in range(5):
        x, = first(dls.test_dl([PILImage.create(cam_files[k])]))
        # Feed the model on whatever device it currently lives on; `x` itself
        # is left untouched for dls.train.decode below.
        x_in = x.to(model_device)
        camimg = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x_in).squeeze())
        a, b = net(x_in).tolist()[0]
        exp_a, exp_b = np.exp(a), np.exp(b)
        normalprob, pneumoniaprob = exp_a / (exp_a + exp_b), exp_b / (exp_a + exp_b)
        # Pick the CAM row / label / probability of the predicted class.
        if normalprob > pneumoniaprob:
            cls_idx, cls_name, cls_prob = 0, "normal", normalprob
        else:
            cls_idx, cls_name, cls_prob = 1, "pneumonia", pneumoniaprob
        dls.train.decode((x,))[0].squeeze().show(ax=ax[i][j])
        ax[i][j].imshow(
            camimg[cls_idx].to("cpu").detach(),
            alpha=0.5,
            extent=(0, 224, 224, 0),
            interpolation='bilinear',
            cmap='magma',
        )
        ax[i][j].set_title("%s(%s)" % (cls_name, cls_prob.round(5)))
        k = k + 1
fig.set_figwidth(16)
fig.set_figheight(16)
fig.tight_layout()

SAMPLE

# Path of the first image file; it is used as the worked example below.
get_image_files(path)[0]
Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0766-0001.jpeg')
# Load the first image file and display it.
img = PILImage.create(get_image_files(path)[0])
img
# Push it through a test DataLoader to obtain a one-image batch tensor.
x, = first(dls.test_dl([img]))  # convert the image to a tensor
x.shape
torch.Size([1, 3, 224, 224])

판단 근거가 강할수록 파란색 $\to$ 보라색으로 변함

# Logits of the sample; a single forward pass yields both class scores.
a, b = net(x).tolist()[0]
# Two-class softmax, in dls.vocab order: (P(NORMAL), P(PNEUMONIA)).
np.exp(a)/(np.exp(a)+np.exp(b)), np.exp(b)/(np.exp(a)+np.exp(b))
(0.9999999999931642, 6.835666941850839e-12)
# Class activation maps: combine the feature maps from net1 with each row of
# the linear head's weight matrix (row 0 -> NORMAL, row 1 -> PNEUMONIA).
camimg = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x).squeeze())

fig, (ax1, ax2) = plt.subplots(1, 2)
for panel_ax, panel_cam, panel_title in (
    (ax1, camimg[0], "NORMAL PART"),
    (ax2, camimg[1], "DISEASE PART"),
):
    dls.train.decode((x,))[0].squeeze().show(ax=panel_ax)
    panel_ax.imshow(
        panel_cam.to("cpu").detach(),
        alpha=0.5,
        extent=(0, 224, 224, 0),
        interpolation='bilinear',
        cmap='cool',
    )
    panel_ax.set_title(panel_title)
fig.set_size_inches(8, 8)
fig.tight_layout()

# Soft masks from the NORMAL map with theta = 0.07:
#   A1 = exp(-theta * (cam - min(cam)))  -> close to 1 where the map is weakest
#   A2 = 1 - A1                          -> grows where the map is strongest
test = camimg[0] - camimg[0].min()
A1 = torch.exp(-0.07 * test)
A2 = 1 - A1

fig, (ax1, ax2) = plt.subplots(1, 2)
for panel_ax, panel_mask, panel_title in (
    (ax1, A2, "MODE1 WEIGHT WITH THETA=0.07"),
    (ax2, A1, "MODE1 RES WEIGHT WITH THETA=0.07"),
):
    dls.train.decode((x,))[0].squeeze().show(ax=panel_ax)
    panel_ax.imshow(
        panel_mask.detach().to("cpu"),
        alpha=0.5,
        extent=(0, 224, 224, 0),
        interpolation='bilinear',
        cmap='cool',
    )
    panel_ax.set_title(panel_title)
fig.set_size_inches(8, 8)
fig.tight_layout()
  • $\theta$가 작아질수록 범위가 좁아지는 경향(?)
# Upsample both theta=0.07 masks to the 224x224 input size and apply them to
# the image on the CPU.
# OpenCV is imported in this notebook as `cv` (`import cv2 as cv`), so it is
# used through that alias; the imports do not bind the bare name `cv2`.
X1 = np.array(A1.to("cpu").detach(), dtype=np.float32)
Y1 = torch.Tensor(cv.resize(X1, (224, 224), interpolation=cv.INTER_LINEAR))
x1 = x.squeeze().to('cpu') * Y1      # MODE1 RES: image weighted by A1
X12 = np.array(A2.to("cpu").detach(), dtype=np.float32)
Y12 = torch.Tensor(cv.resize(X12, (224, 224), interpolation=cv.INTER_LINEAR))
x12 = x.squeeze().to('cpu') * Y12    # MODE1: image weighted by A2 = 1 - A1
  • 1st CAM 결과를 분리하면 아래와 같음.
# The decoded input on its own, for reference.
fig, ax1 = plt.subplots(1, 1)
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_size_inches(4, 4)
fig.tight_layout()

# The two masked inputs side by side: x12 is the image weighted by A2 (MODE1),
# x1 is the image weighted by A1 (MODE1 RES).
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel_ax, panel_img, panel_title in (
    (ax1, x12, "MODE1"),
    (ax2, x1, "MODE1 RES"),
):
    panel_img.squeeze().show(ax=panel_ax)
    panel_ax.set_title(panel_title)
fig.set_size_inches(8, 8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Give the masked image a leading batch dimension and move both halves of the
# model to the CPU, where x1 was built.
x1=x1.reshape(1,3,224,224)
net1.to('cpu')
net2.to('cpu')
Sequential(
  (0): AdaptiveAvgPool2d(output_size=1)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=512, out_features=2, bias=False)
)
# Second-round CAM: same projection as before, now on the MODE1 RES input x1.
# Index 0 of the result is the NORMAL map, index 1 the PNEUMONIA map.
camimg1 = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x1).squeeze())
  • CAM

    • mode1_res에 CAM 결과 올리기
# Overlay both class maps of the second-round CAM on the masked input x1.
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel_ax, panel_cam, panel_title in (
    (ax1, camimg1[0], "NORMAL PART"),
    (ax2, camimg1[1], "DISEASE PART"),
):
    x1.squeeze().show(ax=panel_ax)
    panel_ax.imshow(
        panel_cam.to("cpu").detach(),
        alpha=0.5,
        extent=(0, 223, 223, 0),
        interpolation='bilinear',
        cmap='cool',
    )
    panel_ax.set_title(panel_title)
fig.set_size_inches(8, 8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# NORMAL-class maps of the first and second CAM, both drawn over the original
# (decoded) input so they can be compared directly.
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel_ax, panel_cam, panel_title in (
    (ax1, camimg[0], "1ST CAM"),
    (ax2, camimg1[0], "2ND CAM"),
):
    dls.train.decode((x,))[0].squeeze().show(ax=panel_ax)
    panel_ax.imshow(
        panel_cam.to("cpu").detach(),
        alpha=0.5,
        extent=(0, 223, 223, 0),
        interpolation='bilinear',
        cmap='cool',
    )
    panel_ax.set_title(panel_title)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()

# Class probabilities for the masked input x1; one forward pass gives both
# logits (the model was previously evaluated twice on the same input).
a1, b1 = net(x1).tolist()[0]
np.exp(a1)/(np.exp(a1)+np.exp(b1)), np.exp(b1)/(np.exp(a1)+np.exp(b1))
(1.3657330880802881e-13, 0.9999999999998634)
  • 하나를 지우자 바로 판단 못함.. $\theta$를 변경하면 좀 달라질까?

$\theta=0.02$ 일 때 SAMPLE

# Soft masks from the first NORMAL map with theta = 0.02.
AA1 = torch.exp(-0.02 * test)
AA2 = 1 - AA1

# Plot the theta=0.02 masks themselves. The panels used to draw A2/A1, i.e.
# the theta=0.07 masks, under theta=0.02 titles.
fig, (ax1, ax2) = plt.subplots(1, 2)
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(AA2.data.to("cpu").detach(), alpha=0.5, extent=(0, 224, 224, 0), interpolation='bilinear', cmap='cool')
ax1.set_title("MODE1 WEIGHT WITH THETA=0.02")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(AA1.data.to("cpu").detach(), alpha=0.5, extent=(0, 224, 224, 0), interpolation='bilinear', cmap='cool')
ax2.set_title("MODE1 RES WEIGHT WITH THETA=0.02")
#
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
# Upsample both theta=0.02 masks to the 224x224 input size and apply them to
# the image on the CPU.
# OpenCV is imported in this notebook as `cv` (`import cv2 as cv`), so it is
# used through that alias; the imports do not bind the bare name `cv2`.
XX1 = np.array(AA1.to("cpu").detach(), dtype=np.float32)
YY1 = torch.Tensor(cv.resize(XX1, (224, 224), interpolation=cv.INTER_LINEAR))
xx1 = x.squeeze().to('cpu') * YY1      # MODE1 RES: image weighted by AA1
XX12 = np.array(AA2.to("cpu").detach(), dtype=np.float32)
YY12 = torch.Tensor(cv.resize(XX12, (224, 224), interpolation=cv.INTER_LINEAR))
xx12 = x.squeeze().to('cpu') * YY12    # MODE1: image weighted by AA2 = 1 - AA1
  • 1st CAM 결과를 분리하면 아래와 같음.
# The decoded input on its own, for reference.
fig, ax1 = plt.subplots(1, 1)
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_size_inches(4, 4)
fig.tight_layout()

# The two masked inputs side by side: xx12 is the image weighted by AA2
# (MODE1), xx1 is the image weighted by AA1 (MODE1 RES).
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel_ax, panel_img, panel_title in (
    (ax1, xx12, "MODE1"),
    (ax2, xx1, "MODE1 RES"),
):
    panel_img.squeeze().show(ax=panel_ax)
    panel_ax.set_title(panel_title)
fig.set_size_inches(8, 8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Give the masked image a leading batch dimension and move both halves of the
# model to the CPU, where xx1 was built.
xx1=xx1.reshape(1,3,224,224)
net1.to('cpu')
net2.to('cpu')
Sequential(
  (0): AdaptiveAvgPool2d(output_size=1)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=512, out_features=2, bias=False)
)
# CAM of the masked input xx1: index 0 of the result is the NORMAL map,
# index 1 the PNEUMONIA map.
ver1 = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(xx1).squeeze())
  • CAM
# Overlay both class maps of `ver1` on the masked input xx1.
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel_ax, panel_cam, panel_title in (
    (ax1, ver1[0], "NORMAL PART"),
    (ax2, ver1[1], "DISEASE PART"),
):
    xx1.squeeze().show(ax=panel_ax)
    panel_ax.imshow(
        panel_cam.to("cpu").detach(),
        alpha=0.5,
        extent=(0, 223, 223, 0),
        interpolation='bilinear',
        cmap='cool',
    )
    panel_ax.set_title(panel_title)
fig.set_size_inches(8, 8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# NORMAL-class maps of the first CAM and of the theta=0.02 second CAM, both
# drawn over the original (decoded) input.
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel_ax, panel_cam, panel_title in (
    (ax1, camimg[0], "1ST CAM"),
    (ax2, ver1[0], "2ND CAM"),
):
    dls.train.decode((x,))[0].squeeze().show(ax=panel_ax)
    panel_ax.imshow(
        panel_cam.to("cpu").detach(),
        alpha=0.5,
        extent=(0, 223, 223, 0),
        interpolation='bilinear',
        cmap='cool',
    )
    panel_ax.set_title(panel_title)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()

# Class probabilities for the theta=0.02 masked input xx1.
# This used to evaluate x1 (the theta=0.07 input), which is why the output
# recorded below is identical to the theta=0.07 result; re-run the cell to get
# the real theta=0.02 numbers.
a1, b1 = net(xx1).tolist()[0]
np.exp(a1)/(np.exp(a1)+np.exp(b1)), np.exp(b1)/(np.exp(a1)+np.exp(b1))
(1.3657330880802881e-13, 0.9999999999998634)
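  • $\theta$가 작아졌는데도 결과가 이전과 너무 똑같이 나온다. (주의: 위에 기록된 확률은 $\theta=0.07$ 입력인 `x1`로 계산된 값이라 앞선 결과와 동일하며, $\theta=0.02$ 입력 `xx1`에 대한 결과가 아니다.)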

$\theta=0.04$ 일 때 SAMPLE

# Soft masks from the first NORMAL map with theta = 0.04.
AA1 = torch.exp(-0.04 * test)
AA2 = 1 - AA1

# Plot the theta=0.04 masks themselves. The panels used to draw A2/A1, i.e.
# the theta=0.07 masks, under theta=0.04 titles.
fig, (ax1, ax2) = plt.subplots(1, 2)
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(AA2.data.to("cpu").detach(), alpha=0.5, extent=(0, 224, 224, 0), interpolation='bilinear', cmap='cool')
ax1.set_title("MODE1 WEIGHT WITH THETA=0.04")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(AA1.data.to("cpu").detach(), alpha=0.5, extent=(0, 224, 224, 0), interpolation='bilinear', cmap='cool')
ax2.set_title("MODE1 RES WEIGHT WITH THETA=0.04")
#
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
# Upsample both theta=0.04 masks to the 224x224 input size and apply them to
# the image on the CPU.
# OpenCV is imported in this notebook as `cv` (`import cv2 as cv`), so it is
# used through that alias; the imports do not bind the bare name `cv2`.
XX1 = np.array(AA1.to("cpu").detach(), dtype=np.float32)
YY1 = torch.Tensor(cv.resize(XX1, (224, 224), interpolation=cv.INTER_LINEAR))
xx1 = x.squeeze().to('cpu') * YY1      # MODE1 RES: image weighted by AA1
XX12 = np.array(AA2.to("cpu").detach(), dtype=np.float32)
YY12 = torch.Tensor(cv.resize(XX12, (224, 224), interpolation=cv.INTER_LINEAR))
xx12 = x.squeeze().to('cpu') * YY12    # MODE1: image weighted by AA2 = 1 - AA1
  • 1st CAM 결과를 분리하면 아래와 같음.
# The decoded input on its own, for reference.
fig, ax1 = plt.subplots(1, 1)
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_size_inches(4, 4)
fig.tight_layout()

# The two masked inputs side by side: xx12 is the image weighted by AA2
# (MODE1), xx1 is the image weighted by AA1 (MODE1 RES).
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel_ax, panel_img, panel_title in (
    (ax1, xx12, "MODE1"),
    (ax2, xx1, "MODE1 RES"),
):
    panel_img.squeeze().show(ax=panel_ax)
    panel_ax.set_title(panel_title)
fig.set_size_inches(8, 8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Give the masked image a leading batch dimension and move both halves of the
# model to the CPU, where xx1 was built.
xx1=xx1.reshape(1,3,224,224)
net1.to('cpu')
net2.to('cpu')
Sequential(
  (0): AdaptiveAvgPool2d(output_size=1)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=512, out_features=2, bias=False)
)
# CAM of the masked input xx1: index 0 of the result is the NORMAL map,
# index 1 the PNEUMONIA map.
ver2 = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(xx1).squeeze())
  • CAM
# Overlay both class maps of `ver2` on the masked input xx1.
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel_ax, panel_cam, panel_title in (
    (ax1, ver2[0], "NORMAL PART"),
    (ax2, ver2[1], "DISEASE PART"),
):
    xx1.squeeze().show(ax=panel_ax)
    panel_ax.imshow(
        panel_cam.to("cpu").detach(),
        alpha=0.5,
        extent=(0, 223, 223, 0),
        interpolation='bilinear',
        cmap='cool',
    )
    panel_ax.set_title(panel_title)
fig.set_size_inches(8, 8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# NORMAL-class maps of the first CAM and of the theta=0.04 second CAM, both
# drawn over the original (decoded) input.
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel_ax, panel_cam, panel_title in (
    (ax1, camimg[0], "1ST CAM"),
    (ax2, ver2[0], "2ND CAM"),
):
    dls.train.decode((x,))[0].squeeze().show(ax=panel_ax)
    panel_ax.imshow(
        panel_cam.to("cpu").detach(),
        alpha=0.5,
        extent=(0, 223, 223, 0),
        interpolation='bilinear',
        cmap='cool',
    )
    panel_ax.set_title(panel_title)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()

# Class probabilities for the masked input xx1; one forward pass gives both
# logits (the model was previously evaluated twice on the same input).
a1, b1 = net(xx1).tolist()[0]
np.exp(a1)/(np.exp(a1)+np.exp(b1)), np.exp(b1)/(np.exp(a1)+np.exp(b1))
(0.9903353645622871, 0.009664635437712963)
  • 이 정도면 괜찮은 것 같다.
  • 첫번째 CAM에서 정상 판단 근거였던 폐의 가운데 부분이 어두워지자 약간 오른쪽 폐로 이동한 모습.
  • CAT/DOG 예제에서 $\theta$를 2배씩 늘려나갔으나, 여기서는 $\theta$를 이전과 동일하게 유지함.
# Soft masks from the second-round NORMAL map, theta = 0.04:
#   A3 = exp(-theta * (cam - min(cam)))  -> residual weight (MODE2 RES)
#   A4 = 1 - A3                          -> mode weight (MODE2)
test1 = ver2[0] - torch.min(ver2[0])
A3 = torch.exp(-0.04 * test1)
A4 = 1 - A3

# Same layout as the MODE1 figures: the mode weight (1 - exp) on the left and
# the residual weight (exp) on the right. The two masks were previously drawn
# under each other's title.
fig, (ax1, ax2) = plt.subplots(1, 2)
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(A4.data.to("cpu").detach(), alpha=0.5, extent=(0, 224, 224, 0), interpolation='bilinear', cmap='cool')
ax1.set_title("MODE2 WEIGHT WITH THETA=0.04")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(A3.data.to("cpu").detach(), alpha=0.5, extent=(0, 224, 224, 0), interpolation='bilinear', cmap='cool')
ax2.set_title("MODE2 RES WEIGHT WITH THETA=0.04")
#
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
# Upsample the second-round masks to 224x224 and apply them on top of the
# first-round residual mask YY1.
# OpenCV is imported in this notebook as `cv` (`import cv2 as cv`), so it is
# used through that alias; the imports do not bind the bare name `cv2`.
X3 = np.array(A3.to("cpu").detach(), dtype=np.float32)
Y3 = torch.Tensor(cv.resize(X3, (224, 224), interpolation=cv.INTER_LINEAR))
x3 = x.squeeze().to('cpu') * YY1 * Y3    # MODE2 RES: weighted by YY1 and A3
X4 = np.array(A4.to("cpu").detach(), dtype=np.float32)
Y4 = torch.Tensor(cv.resize(X4, (224, 224), interpolation=cv.INTER_LINEAR))
x4 = x.squeeze().to('cpu') * YY1 * Y4    # MODE2: weighted by YY1 and A4 = 1 - A3
  • 2nd CAM 결과를 분리하면 아래와 같음.
# The decoded input on its own, for reference.
fig, ax1 = plt.subplots(1, 1)
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_size_inches(4, 4)
fig.tight_layout()

# First-round split (xx12 / xx1), then second-round split (x4 / x3), each as
# a two-panel figure.
for left_img, right_img, left_title, right_title in (
    (xx12, xx1, "MODE1", "MODE1 RES"),
    (x4, x3, "MODE2", "MODE2 RES"),
):
    fig, (ax1, ax2) = plt.subplots(1, 2)
    left_img.squeeze().show(ax=ax1)
    right_img.squeeze().show(ax=ax2)
    ax1.set_title(left_title)
    ax2.set_title(right_title)
    fig.set_size_inches(8, 8)
    fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Give the masked image a leading batch dimension and move both halves of the
# model to the CPU, where x3 was built.
x3=x3.reshape(1,3,224,224)
net1.to('cpu')
net2.to('cpu')
Sequential(
  (0): AdaptiveAvgPool2d(output_size=1)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=512, out_features=2, bias=False)
)
# Third-round CAM, computed on the MODE2 RES input x3: index 0 of the result
# is the NORMAL map, index 1 the PNEUMONIA map.
ver22 = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x3).squeeze())
  • CAM
# Overlay both class maps of `ver22` on the masked input x3.
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel_ax, panel_cam, panel_title in (
    (ax1, ver22[0], "NORMAL PART"),
    (ax2, ver22[1], "DISEASE PART"),
):
    x3.squeeze().show(ax=panel_ax)
    panel_ax.imshow(
        panel_cam.to("cpu").detach(),
        alpha=0.5,
        extent=(0, 223, 223, 0),
        interpolation='bilinear',
        cmap='cool',
    )
    panel_ax.set_title(panel_title)
fig.set_size_inches(8, 8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# NORMAL-class maps of all three CAM rounds, each drawn over the original
# (decoded) input.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
for panel_ax, panel_cam, panel_title in (
    (ax1, camimg[0], "1ST CAM"),
    (ax2, ver2[0], "2ND CAM"),
    (ax3, ver22[0], "3RD CAM"),
):
    dls.train.decode((x,))[0].squeeze().show(ax=panel_ax)
    panel_ax.imshow(
        panel_cam.to("cpu").detach(),
        alpha=0.5,
        extent=(0, 223, 223, 0),
        interpolation='bilinear',
        cmap='cool',
    )
    panel_ax.set_title(panel_title)

fig.set_figwidth(12)
fig.set_figheight(12)
fig.tight_layout()

# Class probabilities for the MODE2 RES input x3; one forward pass gives both
# logits (the model was previously evaluated twice on the same input).
a2, b2 = net(x3).tolist()[0]
np.exp(a2)/(np.exp(a2)+np.exp(b2)), np.exp(b2)/(np.exp(a2)+np.exp(b2))
(7.881448771332073e-16, 0.9999999999999991)

  • 전체 그림에 적용하기 (n=11712)
# 5x5 grid of class activation maps for image files 3000..3024.
# Each panel shows the input with the CAM of the predicted class on top, and
# the softmax probability of that class in the title.
cam_files = get_image_files(path)  # list the files once, not once per panel
# net1/net2 were moved to the CPU earlier in this notebook, while test_dl
# batches come out on the DataLoaders' device, so the model input is moved to
# wherever the model currently lives.
model_device = next(net.parameters()).device
fig, ax = plt.subplots(5, 5)
k = 3000
for i in range(5):
    for j in range(5):
        x, = first(dls.test_dl([PILImage.create(cam_files[k])]))
        x_in = x.to(model_device)  # `x` itself stays as-is for dls.train.decode
        camimg = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x_in).squeeze())
        a, b = net(x_in).tolist()[0]
        exp_a, exp_b = np.exp(a), np.exp(b)
        normalprob, pneumoniaprob = exp_a / (exp_a + exp_b), exp_b / (exp_a + exp_b)
        # Pick the CAM row / label / probability of the predicted class.
        if normalprob > pneumoniaprob:
            cls_idx, cls_name, cls_prob = 0, "normal", normalprob
        else:
            cls_idx, cls_name, cls_prob = 1, "pneumonia", pneumoniaprob
        dls.train.decode((x,))[0].squeeze().show(ax=ax[i][j])
        ax[i][j].imshow(
            camimg[cls_idx].to("cpu").detach(),
            alpha=0.5,
            extent=(0, 224, 224, 0),
            interpolation='bilinear',
            cmap='magma',
        )
        ax[i][j].set_title("%s(%s)" % (cls_name, cls_prob.round(5)))
        k = k + 1
fig.set_figwidth(16)
fig.set_figheight(16)
fig.tight_layout()
import pandas as pd

# Collect 0..4 in a list, echoing each value as it is added.
# The earlier version did `col[k] = print(k)` on an empty list, which raises
# IndexError (see the traceback recorded below) and would have stored print()'s
# None return value anyway; append() grows the list instead. The manual
# `k = k + 1` was redundant because range() already advances k.
col = []
for k in range(5):
    print(k)
    col.append(k)
0
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/tmp/ipykernel_170705/2694417875.py in <module>
      2 col=[]
      3 for k in range(5) :
----> 4     col[k]=print(k)
      5     k=k+1

IndexError: list assignment index out of range

# Build an interpretation object from the trained learner and plot its
# confusion matrix (presumably computed on the validation split -- TODO confirm).
interp = ClassificationInterpretation.from_learner(lrnr2)
interp.plot_confusion_matrix()
#cleaner   # removes wrongly predicted images -- seems to show the images that would be removed